from dgl.data import FraudYelpDataset, FraudAmazonDataset
from dgl.data.utils import load_graphs, save_graphs
import dgl
import numpy as np
import torch
import pandas as pd
import os
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder


class Dataset:
    """Load a fraud / anomaly-detection graph by name and store it in ``self.graph``.

    After construction ``self.graph.ndata['label']`` is cast to ``long`` and
    squeezed on its last dimension, ``self.graph.ndata['feature']`` is cast to
    ``float``, and the graph is printed.

    Args:
        name: one of ``'tfinance'``, ``'tsocial'``, ``'ffsd'`` / ``'S-FFSD'``,
            ``'yelp'`` or ``'amazon'``.
        homo: only used for ``'yelp'`` / ``'amazon'``; when true the
            heterograph is converted with ``dgl.to_homogeneous`` (keeping
            feature, label and split masks) and self-loops are added.
        anomaly_alpha: ``'tfinance'`` only. When truthy, randomly chosen normal
            nodes are relabelled as anomalies (and given the features of a
            random original anomaly) until roughly ``anomaly_alpha`` of all
            nodes are anomalies. Takes precedence over ``anomaly_std``.
        anomaly_std: ``'tfinance'`` only, ignored when ``anomaly_alpha`` is
            truthy. When truthy, features are standardised per column and the
            anomaly rows are multiplied by this factor.

    Raises:
        ValueError: if ``name`` is not a supported dataset, or if
            ``anomaly_alpha`` is below the existing 'tfinance' anomaly ratio.
    """

    def __init__(self, name='tfinance', homo=True, anomaly_alpha=None, anomaly_std=None):
        self.name = name
        if name == 'tfinance':
            graph = self._load_tfinance(anomaly_alpha, anomaly_std)
        elif name == 'tsocial':
            graph = load_graphs('dataset/tsocial')[0][0]
        elif name in ('ffsd', 'S-FFSD'):
            graph = self._load_ffsd()
        elif name == 'yelp':
            graph = self._load_fraud_benchmark(FraudYelpDataset, homo)
        elif name == 'amazon':
            graph = self._load_fraud_benchmark(FraudAmazonDataset, homo)
        else:
            # Raise instead of calling the interactive-shell ``exit`` helper so
            # callers can handle it; an uncaught error still exits non-zero.
            raise ValueError(f'no such dataset: {name!r}')

        # Normalise dtypes/shapes so every dataset exposes the same interface.
        graph.ndata['label'] = graph.ndata['label'].long().squeeze(-1)
        graph.ndata['feature'] = graph.ndata['feature'].float()
        print(graph)

        self.graph = graph

    # ------------------------------------------------------------------ #
    # T-Finance
    # ------------------------------------------------------------------ #
    @classmethod
    def _load_tfinance(cls, anomaly_alpha=None, anomaly_std=None):
        """Load T-Finance once, optionally applying one synthetic perturbation."""
        graph = load_graphs('dataset/tfinance')[0][0]
        # Stored labels are one-hot: column 0 marks normal nodes, column 1 anomalies.
        onehot = graph.ndata['label']

        if anomaly_alpha:
            feat, label = cls._inject_anomalies(graph.ndata['feature'], onehot, anomaly_alpha)
            graph.ndata['feature'] = feat
            graph.ndata['label'] = label  # previously computed but never stored
            return graph

        if anomaly_std:
            anomaly_ids = onehot[:, 1].nonzero().squeeze(1).numpy()
            feat = cls._standardize_and_scale(
                graph.ndata['feature'].numpy(), anomaly_ids, anomaly_std)
            graph.ndata['feature'] = torch.tensor(feat)

        graph.ndata['label'] = onehot.argmax(1)
        return graph

    @staticmethod
    def _standardize_and_scale(feat, anomaly_ids, anomaly_std):
        """Z-score ``feat`` per column, then multiply the ``anomaly_ids`` rows by ``anomaly_std``.

        Constant columns (zero std) are left centred at 0 instead of becoming NaN.
        Returns a new array; ``feat`` is not modified.
        """
        mean = feat.mean(0)
        std = feat.std(0)
        std[std == 0] = 1.0  # avoid 0/0 -> NaN for constant feature columns
        scaled = (feat - mean) / std
        scaled[anomaly_ids] *= anomaly_std
        return scaled

    @staticmethod
    def _inject_anomalies(feat, onehot_label, anomaly_alpha):
        """Relabel random normal nodes as anomalies until ~``anomaly_alpha`` of nodes are anomalous.

        Each relabelled node receives a copy of the feature row of a randomly
        chosen original anomaly. Uses the global ``random`` state, so callers can
        seed it with ``random.seed``.

        Args:
            feat: node feature tensor (one row per node); not modified.
            onehot_label: one-hot label tensor, column 0 = normal, column 1 = anomaly.
            anomaly_alpha: target fraction of anomalous nodes.

        Returns:
            ``(new_feat, label)`` where ``label`` is a 1-D class-index tensor.

        Raises:
            ValueError: if ``anomaly_alpha`` is below the current anomaly ratio.
        """
        import random

        label = onehot_label.argmax(1)
        anomaly_ids = onehot_label[:, 1].nonzero().squeeze(1).tolist()
        normal_ids = onehot_label[:, 0].nonzero().squeeze(1).tolist()
        n_new = int(anomaly_alpha * len(label) - len(anomaly_ids))
        if n_new < 0:
            raise ValueError(
                f'anomaly_alpha={anomaly_alpha} is below the existing anomaly ratio '
                f'({len(anomaly_ids)}/{len(label)}); anomalies cannot be removed')

        feat = feat.clone()
        for idx in random.sample(normal_ids, n_new):
            feat[idx] = feat[random.choice(anomaly_ids)]
            label[idx] = 1
        return feat, label

    # ------------------------------------------------------------------ #
    # S-FFSD
    # ------------------------------------------------------------------ #
    @classmethod
    def _load_ffsd(cls):
        """Load the cached S-FFSD graph, or build it from CSV and cache it."""
        graph_path = 'dataset/graph-S-FFSD.bin'
        if os.path.exists(graph_path):
            try:
                print(f"正在加载预处理好的FFSD图：{graph_path}")
                return load_graphs(graph_path)[0][0]
            except Exception as e:  # unreadable cache: deliberately fall back to rebuilding
                print(f"加载预处理图失败: {e}")
        else:
            print(f"加载预处理图失败: 找不到预处理的图文件：{graph_path}")

        print("正在处理FFSD数据集并构建图...")
        graph = cls._build_ffsd_graph()

        os.makedirs(os.path.dirname(graph_path), exist_ok=True)
        save_graphs(graph_path, [graph])
        print(f"FFSD图已保存到: {graph_path}")
        return graph

    @classmethod
    def _build_ffsd_graph(cls, prefix='./data/', edge_per_trans=3):
        """Build the S-FFSD transaction graph from CSV.

        Rows with ``Labels <= 2`` become nodes. For each of Source / Target /
        Location / Type, rows sharing a value are sorted by ``Time`` and each row
        is linked to itself and its next ``edge_per_trans - 1`` rows. Self-loops
        are added at the end.
        """
        try:
            # Prefer the feature-augmented version when it exists.
            df = pd.read_csv(os.path.join(prefix, "S-FFSDneofull.csv"))
            print("已加载S-FFSDneofull.csv")
        except FileNotFoundError:
            df = pd.read_csv(os.path.join(prefix, "S-FFSD.csv"))
            print("已加载S-FFSD.csv")

        # Drop pandas index artefacts and keep only rows with Labels <= 2.
        df = df.loc[:, ~df.columns.str.contains('Unnamed')]
        data = df[df["Labels"] <= 2].reset_index(drop=True)

        print("构建图结构...")
        columns = ["Source", "Target", "Location", "Type"]
        all_src, all_tgt = [], []
        for column in columns:
            # Group on raw values, before label-encoding below.
            for _, group in tqdm(data.groupby(column), desc=column):
                src, tgt = cls._temporal_window_edges(
                    group.sort_values(by="Time").index, edge_per_trans)
                all_src.extend(src)
                all_tgt.extend(tgt)

        # Explicit dtype/num_nodes: empty edge lists stay integer, and every
        # row is a node even if it ends up without edges.
        graph = dgl.graph(
            (np.array(all_src, dtype=np.int64), np.array(all_tgt, dtype=np.int64)),
            num_nodes=len(data))

        for col in columns:
            data[col] = LabelEncoder().fit_transform(data[col].apply(str).values)

        feat_data = data.drop("Labels", axis=1)
        graph.ndata['feature'] = torch.from_numpy(feat_data.values).float()
        graph.ndata['label'] = torch.from_numpy(data["Labels"].values).long()

        return dgl.add_self_loop(graph)

    @staticmethod
    def _temporal_window_edges(sorted_idxs, edge_per_trans=3):
        """Return ``(src, tgt)`` linking each position to itself and its next ``edge_per_trans - 1`` positions."""
        n = len(sorted_idxs)
        src, tgt = [], []
        for i in range(n):
            for j in range(min(edge_per_trans, n - i)):
                src.append(sorted_idxs[i])
                tgt.append(sorted_idxs[i + j])
        return src, tgt

    # ------------------------------------------------------------------ #
    # DGL built-in fraud benchmarks (Yelp / Amazon)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_fraud_benchmark(dataset_cls, homo):
        """Instantiate a DGL fraud dataset; optionally homogenise it and add self-loops."""
        graph = dataset_cls()[0]
        if homo:
            graph = dgl.to_homogeneous(
                graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask'])
            graph = dgl.add_self_loop(graph)
        return graph
